use std::collections::HashMap;

use serde::Serialize;

use crate::rest::post::{Post, PostNoStream, PostStream};

#[derive(Debug, Serialize, Default, Clone)]
pub struct CompletionRequest {
    /// ID of the model to use. Note that not all models are supported for completion.
    pub model: String,
    /// The prompt(s) to generate completions for, encoded as a string, array of
    /// strings, array of tokens, or array of token arrays.
    /// Note that <|endoftext|> is the document separator that the model sees during
    /// training, so if a prompt is not specified the model will generate as if from the
    /// beginning of a new document.
    pub prompt: Prompt,
    /// Generates `best_of` completions server-side and returns the "best" (the one with
    /// the highest log probability per token). Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and
    /// `n` specifies how many to return – `best_of` must be greater than `n`.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<usize>,
    /// Echo back the prompt in addition to the completion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    /// existing frequency in the text so far, decreasing the model's likelihood to
    /// repeat the same line verbatim.
    ///
    /// [more info about frequency/presence penalties](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a JSON object that maps tokens (specified by their token ID in the GPT
    /// tokenizer) to an associated bias value from -100 to 100. You can use this
    /// [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs.
    /// Mathematically, the bias is added to the logits generated by the model prior to
    /// sampling. The exact effect will vary per model, but values between -1 and 1
    /// should decrease or increase likelihood of selection; values like -100 or 100
    /// should result in a ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
    /// from being generated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, isize>>,
    /// Include the log probabilities on the `logprobs` most likely output tokens, as
    /// well the chosen tokens. For example, if `logprobs` is 5, the API will return a
    /// list of the 5 most likely tokens. The API will always return the `logprob` of
    /// the sampled token, so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<usize>,
    /// The maximum number of [tokens](/tokenizer) that can be generated in the
    /// completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's
    /// context length.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
    /// for counting tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// How many completions to generate for each prompt.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
    /// whether they appear in the text so far, increasing the model's likelihood to
    /// talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// If specified, our system will make a best effort to sample deterministically,
    /// such that repeated requests with the same `seed` and parameters should return
    /// the same result.
    ///
    /// Determinism is not guaranteed, and you should refer to the `system_fingerprint`
    /// response parameter to monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<usize>,
    /// Up to 4 sequences where the API will stop generating further tokens. The
    /// returned text will not contain the stop sequence.
    ///
    /// Note: Not supported with latest reasoning models `o3` and `o4-mini`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopKeywords>,
    /// Whether to stream back partial progress. If set, tokens will be sent as
    /// data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
    /// as they become available, with the stream terminated by a `data: [DONE]`
    /// message.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
    pub stream: bool,
    /// Options for streaming response. Only set this when you set `stream: true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// The suffix that comes after a completion of inserted text.
    ///
    /// This parameter is only supported for `gpt-3.5-turbo-instruct`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic.
    ///
    /// It is generally recommended to alter this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with `top_p`
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// It is generally recommended to alter this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor
    /// and detect abuse.
    /// [Learn more from OpenAI](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Add additional JSON properties to the request.
    ///
    /// Flattened during serialization so the entries are merged into the top
    /// level of the request body (an empty map contributes nothing). Without
    /// `#[serde(flatten)]` these would incorrectly serialize nested under an
    /// `"extra_body"` key, which no completions endpoint recognizes.
    #[serde(flatten)]
    pub extra_body: serde_json::Map<String, serde_json::Value>,
}

/// The prompt payload of a [`CompletionRequest`].
///
/// Serialized `untagged`, so each variant maps directly onto the JSON shape
/// the completions endpoint accepts: a single string, an array of strings, an
/// array of token IDs, or an array of token-ID arrays.
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum Prompt {
    /// A single prompt string.
    PromptString(String),
    /// An array of prompt strings (one completion set per prompt).
    PromptStringArray(Vec<String>),
    /// A single prompt given as an array of token IDs.
    TokensArray(Vec<usize>),
    /// Multiple prompts, each given as an array of token IDs.
    TokenArraysArray(Vec<Vec<usize>>),
}
impl Default for Prompt {
    fn default() -> Self {
        Self::PromptString("".to_string())
    }
}

/// Options for a streaming response; see [`CompletionRequest::stream_options`].
///
/// Both fields are plain `bool`s (no `skip_serializing_if`), so they are
/// always present in the serialized request whenever `stream_options` is set.
#[derive(Debug, Serialize, Clone)]
pub struct StreamOptions {
    /// When true, stream obfuscation will be enabled.
    ///
    /// Stream obfuscation adds random characters to an `obfuscation` field on streaming
    /// delta events to normalize payload sizes as a mitigation to certain side-channel
    /// attacks. These obfuscation fields are included by default, but add a small
    /// amount of overhead to the data stream. You can set `include_obfuscation` to
    /// false to optimize for bandwidth if you trust the network links between your
    /// application and the OpenAI API.
    pub include_obfuscation: bool,
    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
    ///
    /// The `usage` field on this chunk shows the token usage statistics for the entire
    /// request, and the `choices` field will always be an empty array.
    ///
    /// All other chunks will also include a `usage` field, but with a null value.
    /// **NOTE:** If the stream is interrupted, you may not receive the final usage
    /// chunk which contains the total token usage for the request.
    pub include_usage: bool,
}

/// Stop sequence(s) for a [`CompletionRequest`].
///
/// Serialized `untagged`, so this becomes either a bare JSON string or an
/// array of strings, matching the two shapes the `stop` parameter accepts.
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum StopKeywords {
    /// A single stop sequence.
    Word(String),
    /// Multiple stop sequences (the API accepts up to 4).
    Words(Vec<String>),
}

impl Post for CompletionRequest {
    /// Reports whether this request expects a streaming response, mirroring
    /// the `stream` field that is serialized into the request body.
    fn is_streaming(&self) -> bool {
        self.stream
    }
}

impl PostNoStream for CompletionRequest {
    /// Non-streaming requests deserialize the full response body into a
    /// [`super::response::Completion`].
    type Response = super::response::Completion;
}

impl PostStream for CompletionRequest {
    /// Streaming requests deserialize each server-sent-event chunk into a
    /// [`super::response::Completion`].
    type Response = super::response::Completion;
}

#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use super::*;

    const QWEN_MODEL: &str = "qwen-coder-turbo-latest";
    const QWEN_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1/completions";
    /// API key loaded from disk at compile time and trimmed lazily on first use.
    ///
    /// Must be a `static`, not a `const`: a `const` containing a `LazyLock`
    /// is inlined as a fresh value at every use site, so each `*QWEN_API_KEY`
    /// would build and initialize a brand-new lock, defeating the one-time
    /// lazy initialization (clippy: `declare_interior_mutable_const`).
    static QWEN_API_KEY: LazyLock<&'static str> =
        LazyLock::new(|| include_str!("../../keys/modelstudio_domestic_key").trim());

    /// Fill-in-the-middle completion (prompt + suffix) with `stream: false`;
    /// prints the whole response body. Requires network access and a valid key.
    #[tokio::test]
    async fn test_qwen_completions_no_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
    package main

    import (
      "fmt"
      "strings"
      "net/http"
      "io/ioutil"
    )

    func main() {

      url := "https://api.deepseek.com/chat/completions"
      method := "POST"

      payload := strings.NewReader(`{
      "messages": [
        {
          "content": "You are a helpful assistant",
          "role": "system"
        },
        {
          "content": "Hi",
          "role": "user"
        }
      ],
      "model": "deepseek-chat",
      "frequency_penalty": 0,
      "max_tokens": 4096,
      "presence_penalty": 0,
      "response_format": {
        "type": "text"
      },
      "stop": null,
      "stream": false,
      "stream_options": null,
      "temperature": 1,
      "top_p": 1,
      "tools": null,
      "tool_choice": "none",
      "logprobs": false,
      "top_logprobs": null
    }`)

      client := &http.Client {
      }
      req, err := http.NewRequest(method, url, payload)

      if err != nil {
        fmt.Println(err)
        return
      }
      req.Header.Add("Content-Type", "application/json")
      req.Header.Add("Accept", "application/json")
      req.Header.Add("Authorization", "Bearer <TOKEN>")

      res, err := client.Do(req)
      if err != nil {
        fmt.Println(err)
        return
      }
      defer res.Body.Close()
"#
                .to_string(),
            ),
            suffix: Some(
                r#"
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println(string(body))
}
"#
                .to_string(),
            ),
            stream: false,
            ..Default::default()
        };

        let result = request_body
            .get_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;
        println!("{}", result);

        Ok(())
    }

    /// Same fill-in-the-middle completion with `stream: true`; prints each
    /// SSE chunk as it arrives. Requires network access and a valid key.
    #[tokio::test]
    async fn test_qwen_completions_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
        package main

        import (
          "fmt"
          "strings"
          "net/http"
          "io/ioutil"
        )

        func main() {

          url := "https://api.deepseek.com/chat/completions"
          method := "POST"

          payload := strings.NewReader(`{
          "messages": [
            {
              "content": "You are a helpful assistant",
              "role": "system"
            },
            {
              "content": "Hi",
              "role": "user"
            }
          ],
          "model": "deepseek-chat",
          "frequency_penalty": 0,
          "max_tokens": 4096,
          "presence_penalty": 0,
          "response_format": {
            "type": "text"
          },
          "stop": null,
          "stream": true,
          "stream_options": null,
          "temperature": 1,
          "top_p": 1,
          "tools": null,
          "tool_choice": "none",
          "logprobs": false,
          "top_logprobs": null
        }`)

          client := &http.Client {
          }
          req, err := http.NewRequest(method, url, payload)

          if err != nil {
            fmt.Println(err)
            return
          }
          req.Header.Add("Content-Type", "application/json")
          req.Header.Add("Accept", "application/json")
          req.Header.Add("Authorization", "Bearer <TOKEN>")

          res, err := client.Do(req)
          if err != nil {
            fmt.Println(err)
            return
          }
          defer res.Body.Close()
    "#
                .to_string(),
            ),
            suffix: Some(
                r#"
        if err != nil {
            fmt.Println(err)
            return
        }
        fmt.Println(string(body))
    }
    "#
                .to_string(),
            ),
            stream: true,
            ..Default::default()
        };

        let mut stream = request_body
            .get_stream_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;

        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(data) => {
                    println!("Received chunk: {:?}", data);
                }
                Err(e) => {
                    eprintln!("Error receiving chunk: {:?}", e);
                    break;
                }
            }
        }

        Ok(())
    }
}
